/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.util;



import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;



/**
 * Implements a {@link WritableByteChannel} that may contain multiple output shards.
 *
 * <p>This provides {@link #writeToShard}, which takes a shard number for writing to a particular
 * shard, or {@link #ALL_SHARDS} to send the same bytes to every shard.
 *
 * <p>The channel is considered open if all downstream channels are open, and closes all downstream
 * channels when closed.
 *
 * <p>This class performs no internal synchronization.
 */
public class ShardingWritableByteChannel implements WritableByteChannel {

  /** Special shard number that causes a write to all shards. */
  public static final int ALL_SHARDS = -2;

  /** Downstream channels; a channel's shard number is its index in this list. */
  private final List<WritableByteChannel> writers = new ArrayList<>();

  /** Returns the number of output shards. */
  public int getNumShards() {
    return writers.size();
  }

  /**
   * Adds another shard output channel. Its shard number is the value {@link #getNumShards()}
   * returned before this call.
   *
   * @param writer the downstream channel for the new shard
   * @throws NullPointerException if {@code writer} is null
   */
  public void addChannel(WritableByteChannel writer) {
    // Fail fast here instead of with an NPE on a later write/isOpen/close.
    writers.add(Objects.requireNonNull(writer, "writer"));
  }

  /**
   * Returns the WritableByteChannel associated with the given shard number.
   *
   * @throws IndexOutOfBoundsException if {@code shardNum} is negative or not less than {@link
   *     #getNumShards()}
   */
  public WritableByteChannel getChannel(int shardNum) {
    return writers.get(shardNum);
  }

  /**
   * Writes the buffer to the given shard.
   *
   * <p>For a non-negative {@code shardNum} this is a single {@code write} on that shard's channel.
   * For {@link #ALL_SHARDS}, every shard is handed the bytes between {@code src}'s position and
   * limit, each through its own view of the buffer. Afterwards {@code src}'s position is advanced
   * to the furthest point reached by any shard. Each shard gets exactly one {@code write} call;
   * partial writes by a shard are not retried.
   *
   * @param shardNum the target shard, or {@link #ALL_SHARDS}
   * @param src the bytes to write
   * @return The total number of bytes written. If the shard number is {@link #ALL_SHARDS}, then the
   *     total is the sum of each individual shard write.
   * @throws IllegalArgumentException if {@code shardNum} is negative and not {@link #ALL_SHARDS}
   * @throws IndexOutOfBoundsException if {@code shardNum} is not less than {@link #getNumShards()}
   * @throws IOException if a downstream write fails
   */
  public int writeToShard(int shardNum, ByteBuffer src) throws IOException {
    if (shardNum >= 0) {
      return writers.get(shardNum).write(src);
    }
    if (shardNum == ALL_SHARDS) {
      return writeToAllShards(src);
    }
    throw new IllegalArgumentException("Illegal shard number: " + shardNum);
  }

  /** Sends the remaining bytes of {@code src} to every shard through independent buffer views. */
  private int writeToAllShards(ByteBuffer src) throws IOException {
    int total = 0;
    int furthest = src.position();
    for (WritableByteChannel writer : writers) {
      // A duplicate shares the content but has its own position. Without it, the first shard
      // would drain src and every later shard would receive nothing.
      ByteBuffer view = src.duplicate();
      total += writer.write(view);
      furthest = Math.max(furthest, view.position());
    }
    // Consume src the way a single write would, so callers looping on hasRemaining() terminate.
    src.position(furthest);
    return total;
  }

  /**
   * Writes a buffer to all shards.
   *
   * <p>Same as calling {@code writeToShard(ALL_SHARDS, buf)}.
   */
  @Override
  public int write(ByteBuffer src) throws IOException {
    return writeToShard(ALL_SHARDS, src);
  }

  /** Returns true only if every downstream channel is open (vacuously true with no shards). */
  @Override
  public boolean isOpen() {
    for (WritableByteChannel writer : writers) {
      if (!writer.isOpen()) {
        return false;
      }
    }
    return true;
  }

  /**
   * Closes every downstream channel.
   *
   * <p>Every channel is closed even if an earlier one fails. The first {@link IOException} is
   * rethrown, and any later ones are attached to it as suppressed exceptions.
   */
  @Override
  public void close() throws IOException {
    IOException failure = null;
    for (WritableByteChannel writer : writers) {
      try {
        writer.close();
      } catch (IOException e) {
        if (failure == null) {
          failure = e;
        } else if (e != failure) {
          // Guard against self-suppression if a channel rethrows a shared exception instance.
          failure.addSuppressed(e);
        }
      }
    }
    if (failure != null) {
      throw failure;
    }
  }
}